// SPDX-FileCopyrightText: Copyright 2025-2025 深圳市同心圆网络有限公司
// SPDX-License-Identifier: GPL-3.0-only

package trace_dao

import (
	"context"
	"strings"
	"time"

	"gitcode.com/opendragonfly/df_proto_gen_go.git/trace_api"
	"github.com/dgraph-io/badger/v4"
	"github.com/patrickmn/go-cache"
	"go.opentelemetry.io/collector/pdata/pcommon"
	"go.opentelemetry.io/collector/pdata/ptrace"
	semconv "go.opentelemetry.io/otel/semconv/v1.25.0"
)

// _TraceDao is the badger-backed storage for OpenTelemetry traces.
//
// Incoming spans are first buffered in memCache, grouped per
// (service name, service version, trace id). When a trace's root span is
// seen, the buffered spans are flushed to the tables and indexes below inside
// a single badger transaction.
type _TraceDao struct {
	// dbh is the badger handle every transaction is opened on.
	dbh *badger.DB

	// Secondary indexes used to list traces by root-span start time and by
	// root-span duration.
	startTimeIndex   *_StartTimeIndex
	consumeTimeIndex *_ConsumeTimeIndex

	// Primary tables: trace summaries, complete traces, and the
	// service / root-span-name catalogue.
	simpleTraceTable *_SimpleTraceTable
	fullTraceTable   *_FullTraceTable
	serviceTable     *_ServiceTable

	// memCache holds *trace_api.TraceFullInfo values for traces that have
	// not been persisted yet.
	memCache *cache.Cache

	// keepTraceDay is forwarded to every table/index Insert; presumably the
	// retention period in days (TODO confirm against the table code).
	keepTraceDay uint
}

// newTraceDao builds a _TraceDao on top of an already opened badger handle.
//
// All tables and indexes start out empty-initialised. The span cache uses a
// 10-minute default expiration and purges expired entries every minute.
// keepTraceDay is stored unchanged and passed along with every persisted write.
func newTraceDao(dbh *badger.DB, keepTraceDay uint) *_TraceDao {
	const (
		spanCacheTTL   = 10 * time.Minute
		spanCacheSweep = 1 * time.Minute
	)

	dao := &_TraceDao{
		dbh:          dbh,
		keepTraceDay: keepTraceDay,
	}
	dao.startTimeIndex = &_StartTimeIndex{}
	dao.consumeTimeIndex = &_ConsumeTimeIndex{}
	dao.simpleTraceTable = &_SimpleTraceTable{}
	dao.fullTraceTable = &_FullTraceTable{}
	dao.serviceTable = &_ServiceTable{}
	dao.memCache = cache.New(spanCacheTTL, spanCacheSweep)
	return dao
}

// ConsumeTraces stores every ResourceSpans contained in td, in order, and
// returns the first error encountered. ResourceSpans handled before the
// failing one have already been processed. ctx is currently not consulted.
func (db *_TraceDao) ConsumeTraces(ctx context.Context, td ptrace.Traces) error {
	all := td.ResourceSpans()
	for idx, total := 0, all.Len(); idx < total; idx++ {
		if err := db.insert(all.At(idx)); err != nil {
			return err
		}
	}
	return nil
}

// parseServiceInfo extracts the service.name and service.version resource
// attributes from attrMap, rendered with AsString. A missing attribute yields
// the empty string.
func (db *_TraceDao) parseServiceInfo(attrMap pcommon.Map) (string, string) {
	lookup := func(key string) string {
		if v, found := attrMap.Get(key); found {
			return v.AsString()
		}
		return ""
	}
	name := lookup(string(semconv.ServiceNameKey))
	version := lookup(string(semconv.ServiceVersionKey))
	return name, version
}

// insert handles one ResourceSpans in two phases.
//
// First, every span is appended to the in-memory cache. Then each trace whose
// root span (a span with no parent) appeared in this ResourceSpans is
// persisted. Persisting stops at the first error, which is returned.
func (db *_TraceDao) insert(resSpan ptrace.ResourceSpans) error {
	serviceName, serviceVersion := db.parseServiceInfo(resSpan.Resource().Attributes())

	// Phase 1: buffer spans and remember which traces just received their root span.
	rootSeen := make(map[string]struct{})
	scopeSpans := resSpan.ScopeSpans()
	for si := 0; si < scopeSpans.Len(); si++ {
		spans := scopeSpans.At(si).Spans()
		for sj := 0; sj < spans.Len(); sj++ {
			span := spans.At(sj)
			db.insertCache(serviceName, serviceVersion, span)
			// A root span marks the trace as ready to be written to the database.
			if span.ParentSpanID().IsEmpty() {
				rootSeen[span.TraceID().String()] = struct{}{}
			}
		}
	}

	// Phase 2: flush those traces from the cache into badger.
	for traceID := range rootSeen {
		if err := db.persistentTrace(serviceName, serviceVersion, traceID); err != nil {
			return err
		}
	}
	return nil
}

// calcCacheKey builds the memCache key identifying one trace of one service
// version.
//
// The three parts are joined with ':'. Backslashes and colons inside each
// part are escaped first (`\` -> `\\`, `:` -> `\:`), so distinct
// (serviceName, serviceVersion, traceId) triples never share a key. Without
// escaping, ("a:b", "c") and ("a", "b:c") would collide and their spans would
// be merged into one cached trace. Parts that contain neither character are
// used verbatim, so ordinary keys still look like "name:version:traceId".
func (db *_TraceDao) calcCacheKey(serviceName, serviceVersion, traceId string) string {
	escape := func(part string) string {
		// Escape the escape character first so the encoding stays decodable.
		part = strings.ReplaceAll(part, `\`, `\\`)
		return strings.ReplaceAll(part, ":", `\:`)
	}
	return strings.Join([]string{escape(serviceName), escape(serviceVersion), escape(traceId)}, ":")
}

// insertCache converts one OTLP span into the trace_api representation.
// It appends the result to the cached TraceFullInfo for
// (serviceName, serviceVersion, traceId), creating that entry on first use.
//
// Span and event timestamps are stored as Unix milliseconds. All attribute
// values are flattened to strings with pcommon.Value.AsString. Unknown span
// kinds map to SPAN_KIND_UNSPECIFIED. Every call re-stores the entry, which
// refreshes its default cache expiration.
//
// NOTE(review): the cached *TraceFullInfo is read, appended to and written
// back without any locking. If ConsumeTraces can run on several goroutines at
// once, spans of the same trace may race or be lost. Confirm that the caller
// serializes calls.
func (db *_TraceDao) insertCache(serviceName, serviceVersion string, span ptrace.Span) {
	traceId := span.TraceID().String()
	cacheKey := db.calcCacheKey(serviceName, serviceVersion, traceId)

	var traceInfo *trace_api.TraceFullInfo
	if cached, ok := db.memCache.Get(cacheKey); ok {
		traceInfo = cached.(*trace_api.TraceFullInfo)
	} else {
		traceInfo = &trace_api.TraceFullInfo{
			TraceId: traceId,
			Service: &trace_api.ServiceInfo{
				ServiceName:    serviceName,
				ServiceVersion: serviceVersion,
			},
			SpanList: []*trace_api.SpanInfo{},
		}
	}

	// Map the OTLP span kind onto the API enum; anything unrecognised stays UNSPECIFIED.
	kind := trace_api.SPAN_KIND_SPAN_KIND_UNSPECIFIED
	switch span.Kind() {
	case ptrace.SpanKindInternal:
		kind = trace_api.SPAN_KIND_SPAN_KIND_INTERNAL
	case ptrace.SpanKindServer:
		kind = trace_api.SPAN_KIND_SPAN_KIND_SERVER
	case ptrace.SpanKindClient:
		kind = trace_api.SPAN_KIND_SPAN_KIND_CLIENT
	case ptrace.SpanKindProducer:
		kind = trace_api.SPAN_KIND_SPAN_KIND_PRODUCER
	case ptrace.SpanKindConsumer:
		kind = trace_api.SPAN_KIND_SPAN_KIND_CONSUMER
	}

	// toAttrList flattens an attribute map into the API key/value list.
	// The result is never nil and keeps the map's iteration order.
	toAttrList := func(attrs pcommon.Map) []*trace_api.AttrInfo {
		list := make([]*trace_api.AttrInfo, 0, attrs.Len())
		attrs.Range(func(k string, v pcommon.Value) bool {
			list = append(list, &trace_api.AttrInfo{
				Key:   k,
				Value: v.AsString(),
			})
			return true
		})
		return list
	}

	events := span.Events()
	eventList := make([]*trace_api.EventInfo, 0, events.Len())
	for i := 0; i < events.Len(); i++ {
		ev := events.At(i)
		eventList = append(eventList, &trace_api.EventInfo{
			Name:      ev.Name(),
			TimeStamp: ev.Timestamp().AsTime().UnixMilli(),
			AttrList:  toAttrList(ev.Attributes()),
		})
	}

	traceInfo.SpanList = append(traceInfo.SpanList, &trace_api.SpanInfo{
		SpanId:         span.SpanID().String(),
		ParentSpanId:   span.ParentSpanID().String(),
		SpanName:       span.Name(),
		StartTimeStamp: span.StartTimestamp().AsTime().UnixMilli(),
		EndTimeStamp:   span.EndTimestamp().AsTime().UnixMilli(),
		SpanKind:       kind,
		AttrList:       toAttrList(span.Attributes()),
		EventList:      eventList,
	})
	db.memCache.Set(cacheKey, traceInfo, cache.DefaultExpiration)
}

// persistentTrace moves one buffered trace from the in-memory cache into
// badger.
//
// It looks up the cache entry for (serviceName, serviceVersion, traceId).
// If there is no entry, nothing happens.
//
// Once an entry is found, it is removed from the cache when this function
// returns, whether or not the writes succeed. If the entry has no root span
// (a span whose ParentSpanId is ""), nothing is written.
//
// Otherwise the following are written in one read-write transaction, each
// with db.keepTraceDay:
//   - the service table entry for the root span name,
//   - the simple trace record (trace id, service, root span),
//   - the full trace record with every cached span,
//   - two consume-time index entries (with and without the root span name),
//   - two start-time index entries (with and without the root span name).
//
// It returns the first write error, or the result of committing the transaction.
func (db *_TraceDao) persistentTrace(serviceName, serviceVersion, traceId string) error {
	// Largest root-span duration (ms) representable in ConsumeTimeKey.ConsumeTime (uint32).
	const maxConsumeMillis = 1<<32 - 1

	cacheKey := db.calcCacheKey(serviceName, serviceVersion, traceId)
	cacheValue, ok := db.memCache.Get(cacheKey)
	if !ok {
		return nil
	}
	defer db.memCache.Delete(cacheKey)
	realCacheValue := cacheValue.(*trace_api.TraceFullInfo)
	var rootSpan *trace_api.SpanInfo
	for _, span := range realCacheValue.SpanList {
		if span.ParentSpanId == "" {
			rootSpan = span
			break
		}
	}
	if rootSpan == nil {
		return nil
	}

	txn := db.dbh.NewTransaction(true)
	// Rolls back on any early return; harmless after a successful Commit.
	defer txn.Discard()

	// Write the service and root span name.
	err := db.serviceTable.Insert(txn, realCacheValue.Service, rootSpan.SpanName, rootSpan.StartTimeStamp, db.keepTraceDay)
	if err != nil {
		return err
	}
	// Write the simple trace table.
	err = db.simpleTraceTable.Insert(txn, &trace_api.TraceInfo{
		TraceId:  realCacheValue.TraceId,
		Service:  realCacheValue.Service,
		RootSpan: rootSpan,
	}, db.keepTraceDay)
	if err != nil {
		return err
	}
	// Write the full trace table.
	err = db.fullTraceTable.Insert(txn, realCacheValue, db.keepTraceDay)
	if err != nil {
		return err
	}

	// Write the consume-time (duration) index.
	// The bucket string is the root span's start hour in the local time zone.
	timeStr := time.UnixMilli(rootSpan.StartTimeStamp).Format("2006-01-02 15")
	// Root span duration in milliseconds, clamped into the uint32 range.
	// A plain uint32 conversion would wrap a negative duration (end before
	// start, e.g. an unset end timestamp) to ~4.29e9 ms, making the trace
	// look like the slowest one. Over-long durations would silently truncate.
	consumeMillis := rootSpan.EndTimeStamp - rootSpan.StartTimeStamp
	if consumeMillis < 0 {
		consumeMillis = 0
	} else if consumeMillis > maxConsumeMillis {
		consumeMillis = maxConsumeMillis
	}
	consumeKey := &ConsumeTimeKey{
		TimeStr:        timeStr,
		ServiceName:    realCacheValue.Service.ServiceName,
		ServiceVersion: realCacheValue.Service.ServiceVersion,
		RootSpanName:   rootSpan.SpanName,
		ConsumeTime:    uint32(consumeMillis),
	}
	err = db.consumeTimeIndex.Insert(txn, consumeKey, &trace_api.TraceIdWithTime{
		TraceId:        realCacheValue.TraceId,
		StartTimeStamp: rootSpan.StartTimeStamp,
	}, db.keepTraceDay)
	if err != nil {
		return err
	}
	// Second entry with an empty root span name.
	consumeKey.RootSpanName = ""
	err = db.consumeTimeIndex.Insert(txn, consumeKey, &trace_api.TraceIdWithTime{
		TraceId:        realCacheValue.TraceId,
		StartTimeStamp: rootSpan.StartTimeStamp,
	}, db.keepTraceDay)
	if err != nil {
		return err
	}

	// Write the start-time index.
	startKey := &StartTimeKey{
		ServiceName:    realCacheValue.Service.ServiceName,
		ServiceVersion: realCacheValue.Service.ServiceVersion,
		RootSpanName:   rootSpan.SpanName,
		StartTime:      rootSpan.StartTimeStamp,
	}
	err = db.startTimeIndex.Insert(txn, startKey, realCacheValue.TraceId, db.keepTraceDay)
	if err != nil {
		return err
	}
	// Second entry with an empty root span name.
	startKey.RootSpanName = ""
	err = db.startTimeIndex.Insert(txn, startKey, realCacheValue.TraceId, db.keepTraceDay)
	if err != nil {
		return err
	}
	return txn.Commit()
}

// ListService returns the services recorded in the service table for the
// [fromTime, toTime] window. The window is evaluated by serviceTable.
// Presumably it is in Unix milliseconds, like the root-span start timestamps
// written by persistentTrace (TODO confirm). Runs in a read-only transaction.
func (db *_TraceDao) ListService(fromTime, toTime int64) ([]*trace_api.ServiceInfo, error) {
	readTxn := db.dbh.NewTransaction(false)
	defer readTxn.Discard()

	services, err := db.serviceTable.ListService(readTxn, fromTime, toTime)
	return services, err
}

// ListRootName returns the root span names recorded for serviceInfo in the
// [fromTime, toTime] window, as reported by the service table. Runs in a
// read-only transaction.
func (db *_TraceDao) ListRootName(serviceInfo *trace_api.ServiceInfo, fromTime, toTime int64) ([]string, error) {
	readTxn := db.dbh.NewTransaction(false)
	defer readTxn.Discard()

	names, err := db.serviceTable.ListRootName(readTxn, serviceInfo, fromTime, toTime)
	return names, err
}

// ListTraceByConsumeTime looks up trace ids through the consume-time
// (root-span duration) index. It then resolves them to trace summaries from
// the simple trace table, with both steps sharing one read-only transaction.
//
// filterByRootSpanName and rootSpanName are forwarded to the index. Presumably
// they select between the per-root-span-name entries and the entries
// persistentTrace writes with an empty root span name (TODO confirm).
func (db *_TraceDao) ListTraceByConsumeTime(serviceInfo *trace_api.ServiceInfo,
	filterByRootSpanName bool, rootSpanName string,
	fromTime, toTime int64, limit uint32) ([]*trace_api.TraceInfo, error) {
	readTxn := db.dbh.NewTransaction(false)
	defer readTxn.Discard()

	ids, listErr := db.consumeTimeIndex.List(readTxn, serviceInfo, filterByRootSpanName, rootSpanName, fromTime, toTime, limit)
	if listErr != nil {
		return nil, listErr
	}
	return db.simpleTraceTable.ListById(readTxn, serviceInfo, ids)
}

// ListTraceByStartTime looks up trace ids through the root-span start-time
// index. It then resolves them to trace summaries from the simple trace
// table, with both steps sharing one read-only transaction.
//
// filterByRootSpanName and rootSpanName are forwarded to the index. Presumably
// they select between the per-root-span-name entries and the entries
// persistentTrace writes with an empty root span name (TODO confirm).
func (db *_TraceDao) ListTraceByStartTime(serviceInfo *trace_api.ServiceInfo,
	filterByRootSpanName bool, rootSpanName string,
	fromTime, toTime int64, limit uint32) ([]*trace_api.TraceInfo, error) {
	readTxn := db.dbh.NewTransaction(false)
	defer readTxn.Discard()

	ids, listErr := db.startTimeIndex.List(readTxn, serviceInfo, filterByRootSpanName, rootSpanName, fromTime, toTime, limit)
	if listErr != nil {
		return nil, listErr
	}
	return db.simpleTraceTable.ListById(readTxn, serviceInfo, ids)
}

// ListTraceById fetches the trace summaries for the given trace ids of
// serviceInfo from the simple trace table, using a read-only transaction.
func (db *_TraceDao) ListTraceById(serviceInfo *trace_api.ServiceInfo, traceIdList []string) ([]*trace_api.TraceInfo, error) {
	readTxn := db.dbh.NewTransaction(false)
	defer readTxn.Discard()

	traces, err := db.simpleTraceTable.ListById(readTxn, serviceInfo, traceIdList)
	return traces, err
}

// GetFullTrace loads the complete stored trace traceId of serviceInfo, with
// all of its spans, from the full trace table. It uses a read-only
// transaction.
func (db *_TraceDao) GetFullTrace(serviceInfo *trace_api.ServiceInfo, traceId string) (*trace_api.TraceFullInfo, error) {
	readTxn := db.dbh.NewTransaction(false)
	defer readTxn.Discard()

	fullTrace, err := db.fullTraceTable.Get(readTxn, serviceInfo, traceId)
	return fullTrace, err
}
